#include "Mesh.h"
#include "Element.h"
#include <iostream>
#include <string>
#include <cmath>
#include <fstream>

#include <Eigen/Sparse>
#include <Eigen/IterativeLinearSolvers>
/// Sparse matrix type used for the global stiffness matrix.
using SpMat = Eigen::SparseMatrix<double>;
/// (row, column, value) entry used to assemble a SpMat.
using Tri = Eigen::Triplet<double>;
/// Dense vector type used for the right-hand side of the linear system.
using RHS = Eigen::VectorXd;
/// Dense vector type used for the vector of solved dof values.
using Solution = Eigen::VectorXd;

/**
 * Exact solution u(p) = sin(pi * p[0]) * sin(2 * pi * p[1]).
 *
 * It is evaluated at boundary dofs to supply the Dirichlet values and
 * serves as the reference when the error of the computed solution is
 * measured.
 *
 * @param _p point at which the function is evaluated.
 *
 * @return value of the exact solution at _p.
 */
double u(const Point<2, double> &_p);

/**
 * Source term f(p) = 5 * pi^2 * u(p).
 *
 * It is evaluated at the mapped quadrature points while the right-hand
 * side vector is assembled.
 *
 * @param _p point at which the function is evaluated.
 *
 * @return value of the source term at _p.
 */
double f(const Point<2, double> &_p);

/**
 * Accumulate the squared difference between the exact solution u and the
 * finite element solution over all elements of the mesh by numerical
 * quadrature, take the square root, print it and return it.
 *
 * @tparam Mesh_Type    mesh class providing get_n_grids(), get_grid() and get_dof().
 * @tparam Element_Type reference element class used for quadrature and basis evaluation.
 *
 * @param mesh     mesh the solution lives on (taken by value).
 * @param solution dof values of the finite element solution, indexed by global dof (taken by value).
 *
 * @return the computed L2 error.
 */
template <typename Mesh_Type, typename Element_Type>
double L2_error(Mesh_Type mesh, Solution solution);

template <typename Mesh_Type, typename Element_Type>
double L2_error(Mesh_Type mesh, Solution solution)
{
    Element_Type ref_ele;
    size_t n_ele = mesh.get_n_grids();
    size_t n_dof = ref_ele.get_n_dofs();
    ///L2error
    ref_ele.read_quad_info("data/triangle.tmp_geo", 6);
    /// Get the quadrature information from the reference element.
    const std::valarray<Point<2, double>> &quad_pnts = ref_ele.get_quad_points();
    const std::valarray<double> &weights = ref_ele.get_quad_weights();
    int n_quad_pnts = ref_ele.get_n_quad_points();
    double volume = ref_ele.get_volume();

    double error = 0.0;
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        /// Find the structure of the current element.
        const Geometry &the_ele = mesh.get_grid(ei);
        /// Find the global indexes of all the degree of freedoms on the
        /// current element.
        const std::valarray<size_t> &dof = the_ele.get_dof();
        /// Within the indexes, we find the coordinates of the dofs
        /// from the mesh, and set to the reference element. Actually
        /// build a mapping between the global dofs and the local
        /// ones. (The indexes and coordinates of local dofs have set
        /// inside P1 element.)
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        /// Begin to assemble ...
        for (size_t l = 0; l < n_quad_pnts; l++)
        {
            ///
            double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);
            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
                u_sol += phi_i * solution[dof[i]];
            }
            error += (u(ref_ele.local2global(quad_pnts[l])) - u_sol) * (u(ref_ele.local2global(quad_pnts[l])) - u_sol) * weights[l] * detJ * volume;
        }
    }
    error = std::sqrt(error);
    std::cout << "L2 error: " << error << std::endl;
    return error;
};

/**
 * Recursively split a triangle into four by joining its edge midpoints
 * and record the triangles of the finest level.
 *
 * @param _nodes         node storage; the three corners are read from it and
 *                       newly created midpoints are written into it.
 * @param local_n1_index position in _nodes of the first corner.
 * @param local_n2_index position in _nodes of the second corner.
 * @param local_n3_index position in _nodes of the third corner.
 * @param new_index      running counters, advanced by the call:
 *                       [0] next free position in _nodes,
 *                       [1] next index value assigned to a node via set_index(),
 *                       [2] next position in connections.
 * @param rec            number of refinement levels still to apply; at 0 the
 *                       triangle itself is recorded.
 * @param connections    output; each recorded triangle stores the indexes
 *                       (get_index()) of its three corners.
 */
void refinement(std::valarray<Point<2, double>> &_nodes,
                int local_n1_index, int local_n2_index, int local_n3_index,
                size_t new_index[3], int rec, std::valarray<std::valarray<size_t>> &connections);

int main(int argc, char *argv[])
{
    P2Element<double> ref_ele;

    ref_ele.read_quad_info("data/triangle.tmp_geo", 2);

    const std::valarray<Point<2, double>> &quad_pnts = ref_ele.get_quad_points();
    const std::valarray<double> &weights = ref_ele.get_quad_weights();
    int n_quad_pnts = ref_ele.get_n_quad_points();
    double volume = ref_ele.get_volume();

    Easymesh mesh;
    mesh.readData("data/D");
    // mesh.readData("test");
    mesh.build_dofs();

    size_t n_ele = mesh.get_n_grids();
    size_t n_dof = ref_ele.get_n_dofs();
    size_t n_total_dofs = mesh.get_n_dofs();

    SpMat K(n_total_dofs, n_total_dofs);
    RHS Rhs(n_total_dofs);
    Rhs.setZero(n_total_dofs);
    std::vector<Tri> TriList(n_ele * n_dof * n_dof * n_quad_pnts);
    std::vector<Tri>::iterator it = TriList.begin();
    std::cout << "its size:" << TriList.size() << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        // std::cout << "ele_id: " << the_ele.get_index() << std::endl;
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        for (size_t l = 0; l < n_quad_pnts; l++)
        {
            double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);
            Matrix<double> invJ = ref_ele.local2global_Jacobi(quad_pnts[l]);

            for (size_t i = 0; i < n_dof; i++)
            {
                std::valarray<double> grad_i = invJ * ref_ele.basis_gradient(i, quad_pnts[l]);
                double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
                for (size_t j = 0; j < n_dof; j++)
                {
                    std::valarray<double> grad_j = invJ * ref_ele.basis_gradient(j, quad_pnts[l]);
                    double cont = (grad_i[0] * grad_j[0] + grad_i[1] * grad_j[1]) * weights[l] * detJ * volume;
                    // A(dof[i], dof[j]) += cont;
                    *it = Tri(dof[i], dof[j], cont);
                    it++;
                }
                /// rhs(i) = \int f(x_i) phi_i dx
                // rhs[dof[i]] += f(ref_ele.local2global(quad_pnts[l])) * phi_i * weights[l] * detJ * volume;

                Rhs[dof[i]] += f(ref_ele.local2global(quad_pnts[l])) * phi_i * weights[l] * detJ * volume;
            }
        }
    }

    K.setFromTriplets(TriList.begin(), TriList.end());
    K.makeCompressed();

    const std::valarray<Point<2, double>> &global_dofs = mesh.get_dofs();
    for (size_t i = 0; i < n_total_dofs; i++)
    {
        int bm = global_dofs[i].get_boundary_mark();
        if (bm == 0)
            continue;
        else if (bm == 1)
        {
            double boundary_value = u(global_dofs[i]);
            Rhs[i] = boundary_value * K.coeffRef(i, i);
            for (Eigen::SparseMatrix<double>::InnerIterator it(K, i); it; ++it)
            {
                size_t row = it.row();
                if (row == i)
                    continue;
                Rhs[row] -= K.coeffRef(i, row) * boundary_value;
                K.coeffRef(i, row) = 0.0;
                K.coeffRef(row, i) = 0.0;
            }
        }
    };

    Eigen::ConjugateGradient<Eigen::SparseMatrix<double>> Solver_sparse;
    Eigen::VectorXd x1_sparse;
    Solver_sparse.setTolerance(1e-15);
    Solver_sparse.compute(K);
    Solution solution = Solver_sparse.solve(Rhs);

    /// stupid error!
    double error = 0.0;
    for (size_t i = 0; i < n_total_dofs; i++)
        error += (solution[i] - u(global_dofs[i])) * (solution[i] - u(global_dofs[i]));
    error = std::sqrt(error);
    std::cout << "stupid error: " << error << std::endl;
    L2_error<Easymesh, P2Element<double>>(mesh, solution);

    /// Output the solution to file in OpenDx format.
    std::valarray<std::valarray<Point<2, double>>> nodes;
    std::valarray<std::valarray<std::valarray<size_t>>> connections;
    std::valarray<std::valarray<double>> values;

    int rec = 2; //迭代次数
    int side_per_grid = pow(3, rec + 1);
    int node_per_grid = 2 + pow(4, rec); //3+3*(1+4+16+...)
    int ele_per_grid = pow(4, rec);
    nodes.resize(n_ele);
    connections.resize(n_ele);
    values.resize(n_ele);

    size_t node_index[3] = {0, 0, 0}; //local_index,global_index,local_element_index
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        nodes[ei].resize(node_per_grid);
        connections[ei].resize(ele_per_grid);
        values[ei].resize(node_per_grid);

        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        node_index[0] = 3;
        node_index[2] = 0;
        for (size_t k = 0; k < n_dof; k++)
        {
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));
        }
        for (size_t k = 0; k < 3; k++)
        {
            nodes[ei][k] = mesh.get_dof(dof[k]);
            nodes[ei][k].set_index(node_index[1]++);
        }
        refinement(nodes[ei], 0, 1, 2, node_index, rec, connections[ei]);
        for (size_t k = 0; k < node_per_grid; k++)
        {
            // Matrix<double> invJ = ref_ele.local2global_Jacobi(nodes[ei][k]); //只需要做global2local
            // double temp;
            // temp = invJ(0, 1);
            // invJ(0, 1) = invJ(1, 0);
            // invJ(1, 0) = temp;
            // std::valarray<double> local = invJ * nodes[ei][k];
            // Point<2, double> local_pnt({local[0], local[1]});
            Point<2, double> local_pnt = ref_ele.global2local(nodes[ei][k]);

            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                double phi_i = ref_ele.basis_function(i, local_pnt);
                u_sol += phi_i * solution[dof[i]];
            }
            values[ei][k] = u_sol;
        }
    }

    std::ofstream output("u.dx");
    output << "object 1 class array type float rank 1 shape 2 item "
           << n_ele * node_per_grid
           << " data follows" << std::endl;
    output.setf(std::ios::fixed);

    output.precision(20);
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        for (size_t k = 0; k < node_per_grid; k++)
        {
            output << nodes[ei][k][0] << "\t" << nodes[ei][k][1] << std::endl;
        }
    }

    output << std::endl;
    output << "object 2 class array type int rank 1 shape 3 item "
           << n_ele * ele_per_grid << " data follows" << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        for (size_t k = 0; k < ele_per_grid; k++)
        {
            output << connections[ei][k][0] << "\t" << connections[ei][k][1] << "\t" << connections[ei][k][2] << "\t" << std::endl;
        }
    }
    output << "attribute \"element type\" string \"triangles\"" << std::endl;
    output << "attribute \"ref\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object 3 class array type float rank 0 item "
           << n_ele * node_per_grid
           << " data follows" << std::endl;
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        for (size_t k = 0; k < node_per_grid; k++)
        {
            output << values[ei][k] << std::endl;
        }
    }

    output << "attribute \"dep\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object \"FEMFunction-2d\" class field" << std::endl;
    output << "component \"positions\" value 1" << std::endl;
    output << "component \"connections\" value 2" << std::endl;
    output << "component \"data\" value 3" << std::endl;
    output << "end" << std::endl;

    output.close();

    /// Output the solution to file in OpenDx format.
    output.open("u2.m");
    output << "T=zeros(" << n_ele * ele_per_grid << ","
           << "3);" << std::endl;
    output << "P=zeros(" << n_ele * node_per_grid << ","
           << "3);" << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        for (size_t k = 0; k < node_per_grid; k++)
        {
            output << "P(" << ei * node_per_grid + k + 1 << ",:)=[" << nodes[ei][k][0] << "\t" << nodes[ei][k][1] << "\t" << values[ei][k] << "];" << std::endl;
        }
    }

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        for (size_t k = 0; k < ele_per_grid; k++)
        {
            output << "T(" << ei * ele_per_grid + k + 1 << ",:)=[" << connections[ei][k][0] + 1 << "\t" << connections[ei][k][1] + 1 << "\t" << connections[ei][k][2] + 1 << "];" << std::endl;
        }
    }

    output << "TR=triangulation(T,P(:,1),P(:,2),P(:,3));" << std::endl;
    output << "trisurf(TR);" << std::endl;

    return 0;
};

/// Exact solution: sin(pi * p[0]) * sin(2 * pi * p[1]).
double u(const Point<2, double> &_p)
{
    const double sin_first = std::sin(M_PI * _p[0]);
    const double sin_second = std::sin(2.0 * M_PI * _p[1]);
    return sin_first * sin_second;
}

/// Source term: 5 * pi^2 times the exact solution u.
double f(const Point<2, double> &_p)
{
    const double factor = 5.0 * M_PI * M_PI;
    return factor * u(_p);
}

/**
 * Recursively split a triangle into four by joining its edge midpoints
 * and record the triangles of the finest level.
 *
 * Midpoints are created per triangle, so an edge shared by two
 * sub-triangles gets its midpoint stored more than once.
 *
 * @param _nodes         node storage; the three corners are read from it and
 *                       newly created midpoints are written into it.
 * @param local_n1_index position in _nodes of the first corner.
 * @param local_n2_index position in _nodes of the second corner.
 * @param local_n3_index position in _nodes of the third corner.
 * @param new_index      running counters, advanced by the call:
 *                       [0] next free position in _nodes,
 *                       [1] next index value assigned to a node via set_index(),
 *                       [2] next position in connections.
 * @param rec            number of refinement levels still to apply; at 0 the
 *                       triangle itself is recorded.
 * @param connections    output; each recorded triangle stores the indexes
 *                       (get_index()) of its three corners.
 */
void refinement(std::valarray<Point<2, double>> &_nodes,
                int local_n1_index, int local_n2_index, int local_n3_index,
                size_t new_index[3], int rec, std::valarray<std::valarray<size_t>> &connections)
{
    if (local_n1_index == local_n2_index || local_n1_index == local_n3_index || local_n2_index == local_n3_index)
        std::cerr << "two equal node index" << std::endl;

    if (rec > 0)
    {
        /// Three new midpoints are about to be stored; refuse to write
        /// past the end of the node storage.
        if (new_index[0] + 3 > _nodes.size())
        {
            std::cerr << "refinement: node storage too small" << std::endl;
            return;
        }

        /// Midpoint k lies on the edge opposite to corner k.
        const int new_node1_index = static_cast<int>(new_index[0]++);
        const int new_node2_index = static_cast<int>(new_index[0]++);
        const int new_node3_index = static_cast<int>(new_index[0]++);

        _nodes[new_node1_index] = mid_point<2, double>(_nodes[local_n2_index], _nodes[local_n3_index]);
        _nodes[new_node1_index].set_index(new_index[1]++);
        _nodes[new_node2_index] = mid_point<2, double>(_nodes[local_n1_index], _nodes[local_n3_index]);
        _nodes[new_node2_index].set_index(new_index[1]++);
        _nodes[new_node3_index] = mid_point<2, double>(_nodes[local_n1_index], _nodes[local_n2_index]);
        _nodes[new_node3_index].set_index(new_index[1]++);

        /// Three corner triangles followed by the central one.
        refinement(_nodes, local_n1_index, new_node2_index, new_node3_index, new_index, rec - 1, connections);
        refinement(_nodes, local_n2_index, new_node1_index, new_node3_index, new_index, rec - 1, connections);
        refinement(_nodes, local_n3_index, new_node1_index, new_node2_index, new_index, rec - 1, connections);
        refinement(_nodes, new_node1_index, new_node2_index, new_node3_index, new_index, rec - 1, connections);
    }
    else
    {
        const size_t slot = new_index[2];
        if (slot >= connections.size())
        {
            std::cerr << "refinement: connection storage too small" << std::endl;
            return;
        }

        /// The caller may have sized only the outer array; make sure the
        /// triangle has room for its three corner indexes before writing.
        std::valarray<size_t> &corners = connections[slot];
        if (corners.size() != 3)
            corners.resize(3);

        corners[0] = _nodes[local_n1_index].get_index();
        corners[1] = _nodes[local_n2_index].get_index();
        corners[2] = _nodes[local_n3_index].get_index();
        new_index[2]++;
    }
}
